import torch 
import torch.nn as nn  
from dalle2_pytorch.dalle2_pytorch import RotaryEmbedding, CausalTransformer, SinusoidalPosEmb, MLP, Rearrange, repeat, rearrange, prob_mask_like, LayerNorm, RelPosBias, Attention, FeedForward
import math

class CrossAttention(nn.Module):
    """Multi-head cross-attention over batch-first sequences.

    Thin wrapper around ``nn.MultiheadAttention`` (built with
    ``batch_first=True``) that returns only the attended output.

    Args:
        embed_dim: Feature dimension handed to ``nn.MultiheadAttention``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Query sequence, batch first.
            key: Key sequence, batch first.
            value: Value sequence, batch first.

        Returns:
            The attention output; it has one row per query position.
        """
        # The attention weights are never used by this wrapper, so tell the
        # layer not to build them; this lets PyTorch pick its optimized path.
        attn_out, _ = self.attn(query, key, value, need_weights=False)
        return attn_out  # [B, N_q, D]
    

 
class ResBlock(nn.Module):
    """Residual unit made of two 3x3 conv + batch-norm stages.

    Both convolutions map ``c`` channels to ``c`` channels with stride 1 and
    padding 1, so the output has the same shape as the input.

    Args:
        c: Number of input (and output) channels.
    """

    def __init__(self, c):
        super().__init__()
        stages = [
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(c),
            nn.ReLU(inplace=True),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(c),
        ]
        # Residual branch; the second batch-norm is left un-activated so the
        # non-linearity is applied once, after the skip connection is added.
        self.f = nn.Sequential(*stages)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x + f(x))``."""
        summed = x + self.f(x)
        return self.act(summed)


class CentralFoveaAttention(nn.Module):
    """Re-weight image tokens with a learned, fovea-like Gaussian window.

    The input is a token sequence whose first token is handled separately
    (it gets its own sigmoid gate) and whose remaining ``grid_size ** 2``
    tokens are laid out as a square patch grid. A small conv network predicts
    a 2-D centre ``mu`` in ``[-1, 1]^2`` and an isotropic width ``sigma``; every
    patch is then weighted according to its distance from ``mu``.

    Args:
        embed_dim: Channel dimension of the incoming tokens.
        grid_size: Side length ``g`` of the patch grid (``g * g`` patches).
    """

    def __init__(self, embed_dim=768, grid_size=16):
        super().__init__()
        self.g = grid_size
        hidden = embed_dim // 4

        # Shared backbone applied to the patch grid.
        self.backbone = nn.Sequential(
            nn.Conv2d(embed_dim, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            ResBlock(hidden),
            ResBlock(hidden)
        )

        # Four independent 1x1 heads.
        self.mu_x_conv   = nn.Conv2d(hidden, 1, 1)
        self.mu_y_conv   = nn.Conv2d(hidden, 1, 1)
        self.logsig_conv = nn.Conv2d(hidden, 1, 1)
        self.cls_conv    = nn.Conv2d(hidden, 1, 1)

        self.pool = nn.AdaptiveAvgPool2d(1)

        # Normalised (x, y) coordinate of every patch, shape (g*g, 2).
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, grid_size),
            torch.linspace(-1, 1, grid_size),
            indexing='ij')
        grid = torch.stack([x.flatten(), y.flatten()], dim=-1)
        self.register_buffer('pos', grid)

        # ---------- weight initialisation ----------
        # All heads start at zero, so initially mu = (0, 0) and the first-token
        # gate is sigmoid(0) = 0.5; the log-sigma bias starts at log(0.25).
        with torch.no_grad():
            for m in [self.mu_x_conv, self.mu_y_conv, self.cls_conv, self.logsig_conv]:
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)
            self.logsig_conv.bias.fill_(math.log(0.25))

    def get_mu_w(self, f_img):
        """Predict the window centre and the per-token weights.

        Args:
            f_img: Tensor of shape ``(B, 1 + g*g, D)``.

        Returns:
            ``(mu, w)`` where ``mu`` is ``(B, 2)`` with entries in ``[-1, 1]``
            and ``w`` is ``(B, 1 + g*g)``: column 0 is the sigmoid gate of the
            first token, the remaining columns sum to 1 over the patches.

        Raises:
            ValueError: If ``f_img`` is not 3-D or its token count is not
                ``1 + g*g``.
        """
        expected_tokens = 1 + self.g * self.g
        # Without this check the ``view`` below can silently succeed on a wrong
        # token count (whenever the element count happens to match) and feed
        # scrambled features to the backbone.
        if f_img.dim() != 3 or f_img.size(1) != expected_tokens:
            raise ValueError(
                f"expected f_img of shape (B, {expected_tokens}, D) for "
                f"grid_size={self.g}, got {tuple(f_img.shape)}"
            )

        B = f_img.size(0)
        patch = f_img[:, 1:]                                  # (B, g*g, D)
        feat = patch.permute(0, 2, 1).contiguous().view(B, -1, self.g, self.g)

        feat = self.backbone(feat)                            # (B, hidden, g, g)

        # Four independent predictions, each globally average-pooled to (B, 1).
        mu_x = torch.tanh(self.pool(self.mu_x_conv(feat)).flatten(1))
        mu_y = torch.tanh(self.pool(self.mu_y_conv(feat)).flatten(1))
        log_sigma = self.pool(self.logsig_conv(feat)).flatten(1)
        cls_logit = self.pool(self.cls_conv(feat)).flatten(1)

        mu = torch.cat([mu_x, mu_y], dim=1)                   # (B, 2)
        # Lower-bound sigma at 0.22 so the window cannot collapse.
        sigma = torch.exp(torch.clamp(log_sigma, min=math.log(0.22)))  # (B, 1)
        cls_w = torch.sigmoid(cls_logit)                      # (B, 1)

        # Patch-wise Gaussian response around mu.
        delta = self.pos.unsqueeze(0) - mu.unsqueeze(1)       # (B, g*g, 2)
        # Defensive floor; only matters if the sigma lower bound is relaxed.
        sigma_sq = torch.clamp_min(sigma ** 2, 1e-4)
        gauss = torch.exp(-0.5 * (delta ** 2).sum(dim=2) / sigma_sq)

        # NOTE(review): the softmax is taken over the Gaussian values
        # themselves (which lie in (0, 1]) rather than over their logarithm,
        # so patch weights differ by at most a factor of e. Kept unchanged so
        # existing checkpoints behave the same; confirm this is intended.
        # No manual max-subtraction is needed: softmax is shift-invariant and
        # its inputs are bounded here.
        patch_w = torch.softmax(gauss, dim=1)                 # (B, g*g)

        # Prepend the first-token gate to the patch weights.
        w = torch.cat([cls_w, patch_w], dim=1)                # (B, 1 + g*g)

        return mu, w

    def forward(self, f_img):
        """Scale every token of ``f_img`` by its weight from ``get_mu_w``.

        Returns a tensor with the same shape as ``f_img``.
        """
        _, w = self.get_mu_w(f_img)
        return f_img * w.unsqueeze(-1)        # (B, 1 + g*g, D)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

# --------------------------------------------------
# 1.  Encoder: outputs μ and logσ² of shape (B, 257, 768)
# --------------------------------------------------
class VoxelEncoderVAE(nn.Module):
    """Encode a flat voxel vector into token-shaped Gaussian parameters.

    Args:
        num_voxels: Size of the input voxel vector.
        token_dim: Feature size of each latent token.
        num_tokens: Number of latent tokens.
        hidden_dim: Width of the residual MLP trunk.
        n_blocks: Number of residual blocks in the trunk.
        drop: Dropout probability inside each block.
    """

    def __init__(self, num_voxels, token_dim=768, num_tokens=257, hidden_dim=256, n_blocks=2, drop=0.15):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_dim = token_dim
        flat_latent = num_tokens * token_dim

        # Project the voxels into the hidden width.
        self.lin0 = nn.Linear(num_voxels, hidden_dim)

        # Pre-norm residual MLP blocks (the skip is added in ``forward``).
        trunk = []
        for _ in range(n_blocks):
            norm = nn.LayerNorm(hidden_dim)
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(drop),
                nn.Linear(hidden_dim, hidden_dim),
            )
            trunk.append(nn.Sequential(norm, mlp))
        self.blocks = nn.ModuleList(trunk)

        # Heads for the mean and the log-variance of the latent.
        self.fc_mu = nn.Linear(hidden_dim, flat_latent)
        self.fc_logvar = nn.Linear(hidden_dim, flat_latent)

    def forward(self, x):
        """Map voxels to the latent mean and log-variance.

        Args:
            x: Tensor of shape ``(B, num_voxels)``.

        Returns:
            ``(mu, logvar)``, each of shape ``(B, num_tokens, token_dim)``;
            ``logvar`` is squashed into ``[-10, 10]``.
        """
        h = self.lin0(x)  # (B, hidden_dim)
        for block in self.blocks:
            h = block(h) + h

        token_shape = (-1, self.num_tokens, self.token_dim)
        mu = self.fc_mu(h).view(*token_shape)
        raw_logvar = self.fc_logvar(h).view(*token_shape)
        # tanh keeps the log-variance bounded to [-10, 10].
        logvar = torch.tanh(raw_logvar) * 10
        return mu, logvar


# --------------------------------------------------
# 2.  Reparameterization (the latent stays 3-D)
# --------------------------------------------------
def reparameterize(mu, logvar):
    """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

    Args:
        mu: Mean of the latent Gaussian.
        logvar: Log-variance of the latent Gaussian, same shape as ``mu``.

    Returns:
        A sample with the same shape as the inputs.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


# --------------------------------------------------
# 3.  Decoder: reconstructs the voxels from a (B, 257, 768) latent
# --------------------------------------------------
class VoxelDecoderVAE(nn.Module):
    """Decode a token-shaped latent back into a flat voxel vector.

    Args:
        num_voxels: Size of the reconstructed voxel vector.
        token_dim: Feature size of each latent token.
        num_tokens: Number of latent tokens.
        hidden_dim: Width of the residual MLP trunk.
        n_blocks: Number of residual blocks in the trunk.
        drop: Dropout probability inside each block.
    """

    def __init__(self, num_voxels, token_dim=768, num_tokens=257, hidden_dim=256, n_blocks=2, drop=0.15):
        super().__init__()
        self.num_tokens = num_tokens
        self.token_dim = token_dim

        # Map the flattened token sequence into the hidden width.
        self.lin0 = nn.Linear(num_tokens * token_dim, hidden_dim)

        # Pre-norm residual MLP blocks (the skip is added in ``forward``).
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(drop),
                    nn.Linear(hidden_dim, hidden_dim),
                )
            )
            for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(hidden_dim, num_voxels)

    def forward(self, z):
        """Reconstruct voxels from a latent.

        Args:
            z: Tensor of shape ``(B, num_tokens, token_dim)``; it may be
                non-contiguous (e.g. the result of a ``transpose``).

        Returns:
            Tensor of shape ``(B, num_voxels)``.
        """
        B = z.size(0)
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # input, while ``reshape`` gives the same result and only copies when
        # it has to.
        h = z.reshape(B, -1)  # (B, num_tokens * token_dim)
        h = self.lin0(h)      # (B, hidden_dim)
        for block in self.blocks:
            h = block(h) + h
        return self.lin1(h)   # (B, num_voxels)


# --------------------------------------------------
# 4.  Full VAE model (keeps the 3-D latent)
# --------------------------------------------------
class VoxelVAE(nn.Module):
    """Variational auto-encoder over voxel vectors with a token-shaped latent.

    Pairs a ``VoxelEncoderVAE`` with a ``VoxelDecoderVAE`` built from the same
    hyper-parameters.

    Args:
        num_voxels: Size of the voxel vector.
        token_dim: Feature size of each latent token.
        num_tokens: Number of latent tokens.
        hidden_dim: Width of the encoder/decoder trunks.
        n_blocks: Number of residual blocks in each trunk.
        drop: Dropout probability inside each block.
    """

    def __init__(self, num_voxels, token_dim=768, num_tokens=257, hidden_dim=256, n_blocks=2, drop=0.15):
        super().__init__()
        shared = dict(
            num_voxels=num_voxels,
            token_dim=token_dim,
            num_tokens=num_tokens,
            hidden_dim=hidden_dim,
            n_blocks=n_blocks,
            drop=drop,
        )
        self.encoder = VoxelEncoderVAE(**shared)
        self.decoder = VoxelDecoderVAE(**shared)

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        A fresh latent sample is drawn on every call.

        Returns:
            ``(z, recon, mu, logvar)``: the sampled latent, the reconstructed
            voxels, and the encoder's mean and log-variance.
        """
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        return z, self.decoder(z), mu, logvar

    def decode(self, z):
        """Run only the decoder on a given latent ``z``."""
        recon = self.decoder(z)
        return recon